import math
from typing import List, Tuple
import cv2

import numpy as np
import maxflow


def three_point_curvature(p1, p2, p3):
    """Return the Menger curvature of the circle through three 2-D points.

    The curvature is ``4 * area / (a * b * c)``, where ``area`` is the area of
    the triangle (p1, p2, p3) and ``a``, ``b``, ``c`` are its side lengths.
    Collinear points give 0. When any two points coincide (zero denominator),
    0.0 is returned instead of dividing.

    Each point must unpack into exactly two coordinates ``(x, y)``.
    """
    (xa, ya), (xb, yb), (xc, yc) = p1, p2, p3

    def _side(ux, uy, vx, vy):
        # Euclidean distance between (ux, uy) and (vx, vy).
        return math.sqrt((ux - vx) ** 2 + (uy - vy) ** 2)

    # The shoelace sum is twice the signed triangle area, so doubling its
    # magnitude yields 4 * area.
    shoelace = xa * (yb - yc) + xb * (yc - ya) + xc * (ya - yb)
    numerator = 2 * abs(shoelace)
    denominator = (
        _side(xb, yb, xc, yc)
        * _side(xa, ya, xc, yc)
        * _side(xa, ya, xb, yb)
    )
    if not denominator > 0:
        return 0.0
    return numerator / denominator


def _stroke_corner_candidates(pts, W, theta_thr):
    """Return ``(indices, curvatures)`` of samples whose window angle is below ``theta_thr``.

    For each sample ``i`` in ``[W, n - W)``, the angle at ``pts[i]`` between
    the vectors to ``pts[i - W]`` and ``pts[i + W]`` is measured. A straight
    run gives pi; a sharp turn gives a small angle. Samples with an angle
    strictly below ``theta_thr`` (radians) are kept, together with the
    three-point curvature of ``(pts[i - W], pts[i], pts[i + W])``. Windows
    with a zero-length arm are skipped.
    """
    n = pts.shape[0]
    indices: List[int] = []
    kappas: List[float] = []
    for i in range(W, n - W):
        p_prev, p_cur, p_next = pts[i - W], pts[i], pts[i + W]
        v1 = p_prev - p_cur
        v2 = p_next - p_cur
        norm1 = np.linalg.norm(v1)
        norm2 = np.linalg.norm(v2)
        if norm1 == 0 or norm2 == 0:
            continue
        cosang = float(np.dot(v1, v2) / (norm1 * norm2))
        # Clamp to guard acos against rounding just outside [-1, 1].
        theta = math.acos(max(-1.0, min(1.0, cosang)))
        if theta < theta_thr:
            indices.append(i)
            kappas.append(three_point_curvature(p_prev, p_cur, p_next))
    return indices, kappas


def _graph_cut_split_mask(coords, kappas, beta, r, sigma):
    """Label each candidate as split (True) or keep (False) via a binary s-t min cut.

    Data term: splitting candidate ``i`` costs ``exp(-beta * kappa_i)`` and
    keeping it costs the complement, so high-curvature candidates are cheaper
    to split. Smoothness term: every pair of candidates closer than ``r``
    pays ``exp(-dist / sigma)`` when they receive different labels.
    """
    m = len(kappas)
    g = maxflow.Graph[float](m, m * 4)
    node_ids = [int(v) for v in g.add_nodes(m)]

    for node, kappa in zip(node_ids, kappas):
        split_cost = math.exp(-beta * kappa)
        # PyMaxflow: a node ending in the sink segment (label 1 = split) cuts
        # its source edge and pays cap_source; a node in the source segment
        # (label 0 = keep) pays cap_sink.
        # NOTE(review): the keep cost 1 - split_cost is an assumed complement.
        # With a zero keep cost the cut is trivial and nothing ever splits.
        # Confirm against the method's reference energy.
        g.add_tedge(node, split_cost, 1.0 - split_cost)

    for i in range(m):
        for j in range(i + 1, m):
            dist = np.linalg.norm(coords[i] - coords[j])
            if dist < r:
                phi = math.exp(-dist / sigma)
                g.add_edge(node_ids[i], node_ids[j], phi, phi)

    g.maxflow()
    return [g.get_segment(node) == 1 for node in node_ids]


def _cut_stroke_at(pts, split_idxs):
    """Cut ``pts`` at each ascending index in ``split_idxs``.

    Consecutive pieces share their split sample: the piece ending at index
    ``k`` includes ``pts[k]``, and the next piece starts at ``pts[k]``. With
    no split indices, a single copy of ``pts`` is returned.
    """
    if not split_idxs:
        return [pts.copy()]
    bounds = [0, *split_idxs, pts.shape[0] - 1]
    return [pts[a:b + 1].copy() for a, b in zip(bounds, bounds[1:])]


def split_strokes(strokes, W, theta0, rho, l0, beta, r, sigma):
    """Split strokes at sharp corners selected by a graph cut over corner candidates.

    Args:
        strokes: iterable of point arrays, each of shape ``(n, 2)``. Rows are
            unpacked as ``(x, y)`` by ``three_point_curvature``.
        W: window half-width in samples used to measure the angle at each
            point. Strokes with fewer than ``2 * W + 1`` points are returned
            unchanged (as copies).
        theta0, rho, l0: parameters of the angle threshold
            ``theta0 + rho * exp(-2 * W / l0)`` (radians). A point is a
            candidate when its window angle is strictly below it.
        beta: curvature weight; the split cost of a candidate is
            ``exp(-beta * kappa)``.
        r: only candidate pairs closer than ``r`` are linked in the graph.
        sigma: distance decay of the pairwise smoothness weight
            ``exp(-dist / sigma)``.

    Returns:
        A list of new point arrays (copies). Split strokes become several
        pieces that share their split samples.

    Raises:
        ValueError: if ``W`` is negative.
    """
    if W < 0:
        raise ValueError(f"W must be non-negative, got {W}")

    # NOTE(review): ell is the window span (2 * W samples), so the threshold
    # is the same for every stroke. If a per-stroke length was intended,
    # compute it inside the loop instead.
    ell = 2 * W
    theta_thr = theta0 + rho * math.exp(-ell / l0)

    new_strokes: List[np.ndarray] = []
    for pts in strokes:
        if pts.shape[0] < 2 * W + 1:
            new_strokes.append(pts.copy())
            continue

        candidates, kappas = _stroke_corner_candidates(pts, W, theta_thr)
        if not candidates:
            new_strokes.append(pts.copy())
            continue

        split_mask = _graph_cut_split_mask(pts[candidates], kappas, beta, r, sigma)
        # candidates are generated in ascending order, so this stays sorted.
        split_idxs = [c for c, is_split in zip(candidates, split_mask) if is_split]
        new_strokes.extend(_cut_stroke_at(pts, split_idxs))

    return new_strokes
